from getinput import GetUserInput
import os
import logging
import csv
import re
import sys
import socket

# Module-level logger, looked up by the literal name 'root' (presumably the
# project's shared logger configured elsewhere -- confirm).
logger = logging.getLogger('root')

# Shared prompt helper from the project-local getinput module; every
# interactive question asked by GetNetworkHosts goes through responses.ask().
# Note: it is instantiated at import time.
responses = GetUserInput()

class GetNetworkHosts(object):
    
    def __init__(self, hlist=None, hfile=None):
        self.hosts = set()
        self.dupes = []
        if hlist:
            if isinstance(hlist, list):
                self.validate_host_list(hlist)
            else:
                logger.warning('HOST_LIST is not a list.')            
        elif hfile:
            if not self.host_file(hfile):
                logger.warning('HOST_FILE is not valid.')
        if not self.hosts:
            self.options()

    def options(self):
        logger.debug('Giving options.')
        c = responses.ask(
        """Please choose from one of the following options:\n
        \t1. Load targets from CSV.
        \t2. Scan network with NMAP.
        \t3. Enter single target name or IP.
        \t4. Build hosts list.\n""", 'multi', ['1', '2', '3', '4']
        )
        logger.debug('Option chosen: {}'.format(c))
        print ''
        if c == '1':
            if not self.host_file():
                self.options()
        elif c == '2':
            self.nmap_scan()
        elif c == '3':
            host = responses.ask('Enter IP address or hostname.', 'raw')
            if host:
                if self.validate_host(host):
                    self.hosts.add((host, host))
                    return
            self.options()
        elif c == '4':
            self.build_hosts()

    def list_found_files(self):
        import time
        files = [f for f in os.listdir('.') if '.csv' in f]
        print 'Choose a file:'
        for c, f in enumerate(files):
            print '{}. {} -> {}'.format(c+1, f, time.asctime(time.localtime(os.stat(f).st_ctime)))
        print '0. None of these'
        choice = responses.ask('', 'multi', [str(e) for e in range(len(files)+1)])
        print ''
        if int(choice) == 0:
            return False
        return files[int(choice)-1]

            
    def host_file(self, hfile=None):
        if not hfile:
            hfile = self.list_found_files()
            if not hfile:
                hfile = responses.ask('Enter filename', 'file')
        logger.info('Checking host file: {}'.format(hfile))
        try:
            parsed = [host[0] for host in self.read_from_csv(hfile)]
        except:
            logger.error('Could not parse file.')
            return False
        else:
            logger.info('File parsed.')
            self.validate_host_list(parsed)
            return True

    def build_hosts(self, hlist=None):
        if hlist:
            self.hosts, self.dupes = self.validate_host_list(hlist)
            return
        hosts = []
        logger.info('Building host list.')
        while True:
            hname = responses.ask('Enter host name or leave blank and press Enter to stop.\n', 'raw')
            if not hname:
                logger.info('No name given; breaking out.')
                break
            else:
                hosts.append(hname)
                print hname
                logger.info('Added name: {}'.format(hname))
        print '\n'.join(self.hosts)
        self.validate_host_list(hosts)
        logger.info('Hosts validated, writing to file.')
        hfile = 'build_hosts-{}.csv'.format(id(self.hosts))
        self.write_to_csv(hfile, [ [e[0]] for e in self.hosts ])

    def nmap_scan(self):
        logger.info('Using NMAP to scan network for hosts.')
        scanNet = responses.ask('Please enter network to scan in the format x.x.x.x', 'ip')
        scanBit = responses.ask('Please enter subnet bits' 'num', xrange(256))
        network = '{}/{}'.format(scanNet, scanBit)
        logger.info('Scanning network: {}'.format(network))
        try:
            import nmap
            nm = nmap.PortScanner()
        except WindowsError:
            logger.error('Unable to initialize nmap.  Please installation.')
            sys.exit()
        nscan = nm.scan(hosts=network, arguments = '-sS -p 135')
        self.hosts = set(nscan['scan'][s]['hostname'] for s in nscan['scan']
                      if 'open' in nscan['scan'][s]['tcp'][135]['state'])
        if not self.hosts:
            logger.warning('No hosts found in scan.')
            self.options()
        else:
            logger.info('Hosts found.')
        nfile = 'nmap_scan_hosts-{}.csv'.format(id(self.hosts))
        self.write_to_csv(nfile, [ [e] for e in self.hosts ])

    def validate_host(self, host):
        logger.info('Testing connection to host: {}'.format(host))
        try:
            test = socket.create_connection((host, 135), 5)
        except socket.gaierror:
            logger.warning('DNS failed.')
            return False
        except socket.timeout:
            logger.warning('Connection timed out.')
            return False
        else:
            logger.info('Test successful.')
            test.close()
            return True

    def validate_host_list(self, hosts):
        logger.info('Validating list of hosts.')
        print 'Validating hosts.  Please stand by.\n'
        ip = re.compile(r'^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$')
        shortList = set()
        for host in hosts:
            print 'Checking {}...'.format(host),
            if re.match(ip, host):
                if any([e for e in re.match(ip, host).groups()
                       if int(e) > 255 or int(e) < 0]):
                    logger.warning('Invalid IP: {}'.format(host))
                    print ' invalid.'
                else:
                    logger.info('Adding IP: {}'.format(host))
                    print ' added.'
                    self.hosts.add((host, str(host)))
            elif set(host) - set('0123456789.') == set([]):
                logger.warning('Bad host: {}'.format(host))
                print ' invalid.'
            elif host.lower().split('.')[0] not in shortList:
                logger.info('Adding host: {}'.format(host))
                print ' added.'
                short = host.lower().split('.')[0]
                self.hosts.add((host, short))
                shortList.add(short)
            else:
                logger.warning('Duplicate host found: {}'.format(host))
                print ' duplicate.'
                self.dupes.append(host)
        if not responses.ask('\nAdd hosts with failed connections? (y/n)', 'bool'):
            logger.info('Checking connection to hosts.')
            print '\nChecking host connections:\n'
            for host in list(self.hosts):
                print '{}...'.format(host[0]),
                if not self.validate_host(host[0]):
                    logger.warning('Failed: {}'.format(host[0]))
                    print ' failed.'
                    self.hosts.remove(host)
                logger.warning('Success: {}'.format(host[0]))
                print ' success.'


    def write_to_csv(self, fn, data):
        logger.info('Writing to: {}'.format(fn))
        try:
            with open(fn, 'wb') as open_file:
                csv_out = csv.writer(open_file)
                csv_out.writerows(data)
        except IOError as e:
            logger.error('Error writing to file {}: {}'.format(fn, e))
        else:
            logger.info('File written.')

    def read_from_csv(self, fn):
        logger.info('Reading from: {}'.format(fn))
        try:
            with open(fn, 'rb') as open_file:
                csv_in = [e for e in csv.reader(open_file)]
        except IOError as e:
            logger.error('Error reading from file {}: {}'.format(fn, e))
            return False
        else:
            logger.info('File read.')
            return csv_in

    def __str__(self):
        hosts = '\n'.join([e[0] for e in self.hosts])
        return '\nThe following hosts will be scanned ({}):\n\n{}\n'.format(len(self.hosts), hosts)